{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From d:\\software\\anaconda3\\envs\\bdmi\\lib\\site-packages\\tensorflow_core\\python\\compat\\v2_compat.py:65: disable_resource_variables (from tensorflow.python.ops.variable_scope) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "non-resource variables are not supported in the long term\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "\n",
    "# Set before importing TensorFlow: the native logger may read this variable\n",
    "# only once, so assigning it after the import can be too late to take effect.\n",
    "os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'\n",
    "\n",
    "import numpy as np\n",
    "import tensorflow.compat.v1 as tf\n",
    "\n",
    "# The rest of the notebook uses the TF 1.x graph/session API.\n",
    "tf.disable_v2_behavior()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the whole training corpus into memory; any plain-text file works.\n",
    "# The context manager closes the file handle as soon as the text is read.\n",
    "with open('min-char-rnn-tensorflow.py', 'r') as corpus_file:\n",
    "    data = corpus_file.read()\n",
    "\n",
    "# Sort the vocabulary so character ids are the same on every run; the iteration\n",
    "# order of a set of strings changes between interpreter sessions.\n",
    "chars = sorted(set(data))\n",
    "data_size, vocab_size = len(data), len(chars)\n",
    "\n",
    "# Lookup tables: character -> id and id -> character.\n",
    "char2idx = {ch: i for i, ch in enumerate(chars)}\n",
    "idx2char = np.array(chars)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_raw_data():\n",
    "    \"\"\"Encode the corpus as character ids and pair each id with its successor.\n",
    "\n",
    "    Returns:\n",
    "        A tuple ``(inputs, targets)`` of NumPy arrays in which ``targets[i]`` is\n",
    "        the id of the character that follows ``inputs[i]`` in the corpus.\n",
    "    \"\"\"\n",
    "    encoded = np.array([char2idx.get(ch) for ch in data])\n",
    "    inputs, targets = encoded[:-1], encoded[1:]\n",
    "    return inputs, targets\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model and training hyper-parameters.\n",
    "state_size = 100      # width of the recurrent hidden state\n",
    "batch_size = 5        # number of rows the corpus is split into per batch\n",
    "seq_length = 25       # time steps unrolled per training step\n",
    "learning_rate = 1e-1  # step size passed to the Adagrad optimizer\n",
    "\n",
    "def get_batch_seq(data):\n",
    "    \"\"\"Yield successive ``(x, y)`` mini-batches that together cover one epoch.\n",
    "\n",
    "    The sequence is cut into ``batch_size`` contiguous rows and every row is\n",
    "    walked left to right in windows of ``seq_length`` columns, so row ``i`` of\n",
    "    one batch continues row ``i`` of the previous batch.\n",
    "\n",
    "    Args:\n",
    "        data: A ``(raw_x, raw_y)`` pair of 1-D id arrays of the same length,\n",
    "            such as the pair returned by ``get_raw_data``.\n",
    "\n",
    "    Yields:\n",
    "        Tuples ``(x, y)`` of arrays with shape ``(batch_size, seq_length)``.\n",
    "        Nothing is yielded when a row is shorter than ``seq_length``.\n",
    "    \"\"\"\n",
    "    raw_x, raw_y = data\n",
    "    batch_partition_length = len(raw_x) // batch_size\n",
    "\n",
    "    # Keep exactly batch_size full rows. An explicit end index avoids the\n",
    "    # `[:-0]` trap, which returns an empty array whenever the length is already\n",
    "    # an exact multiple, and it never divides by a zero row length.\n",
    "    usable = batch_size * batch_partition_length\n",
    "    data_x = raw_x[:usable].reshape(batch_size, batch_partition_length)\n",
    "    data_y = raw_y[:usable].reshape(batch_size, batch_partition_length)\n",
    "\n",
    "    epoch_steps = batch_partition_length // seq_length\n",
    "    for step in range(epoch_steps):\n",
    "        start = step * seq_length\n",
    "        stop = start + seq_length\n",
    "        yield data_x[:, start:stop], data_y[:, start:stop]\n",
    "\n",
    "def get_epoch(n):\n",
    "    \"\"\"Lazily produce ``n`` epochs, each one a fresh batch generator over the corpus.\"\"\"\n",
    "    yield from (get_batch_seq(get_raw_data()) for _ in range(n))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CharRnnModel():\n",
    "    \"\"\"Minimal character-level RNN written against the TensorFlow 1.x graph API.\n",
    "\n",
    "    Every instance owns a ``tf.Session`` and prefixes its variable scopes with\n",
    "    ``id(self)``, which keeps the variables of different live instances apart.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        self.sess = tf.Session()\n",
    "\n",
    "    def close(self):\n",
    "        \"\"\"Close the session owned by this model.\"\"\"\n",
    "        self.sess.close()\n",
    "\n",
    "    def create_compute_graph(self):\n",
    "        \"\"\"Build the training graph, unrolled over ``seq_length`` time steps.\n",
    "\n",
    "        Returns:\n",
    "            A tuple ``(x, y, init_state, final_state, total_loss, update)``:\n",
    "            the input, label and initial-state placeholders, the state after\n",
    "            the last time step, the mean cross-entropy loss and the Adagrad\n",
    "            training op.\n",
    "        \"\"\"\n",
    "        cell_scope = str(id(self)) + 'rnn_cell'\n",
    "\n",
    "        # Create the recurrent weights once; each time step below looks them\n",
    "        # up again with reuse=True so all steps share the same parameters.\n",
    "        with tf.variable_scope(cell_scope):\n",
    "            tf.get_variable('w',[vocab_size + state_size, state_size])\n",
    "            tf.get_variable('b',[state_size])\n",
    "\n",
    "        def rnn_cell(rnn_input,pre_state):\n",
    "            \"\"\"Apply one tanh step to the concatenated input and previous state.\"\"\"\n",
    "            with tf.variable_scope(cell_scope,reuse=True):\n",
    "                w = tf.get_variable('w',[vocab_size + state_size, state_size])\n",
    "                b = tf.get_variable('b',[state_size])\n",
    "            return tf.tanh(tf.matmul(tf.concat([rnn_input,pre_state],axis=1), w) + b)\n",
    "\n",
    "        x = tf.placeholder(tf.int32, [None, seq_length])\n",
    "        y = tf.placeholder(tf.int32, [None, seq_length])\n",
    "        init_state = tf.placeholder(tf.float32,[None, state_size])\n",
    "\n",
    "        x_one_hot = tf.one_hot(x,vocab_size)\n",
    "        y_one_hot = tf.one_hot(y,vocab_size)\n",
    "\n",
    "        # Split along the time axis: one tensor per time step.\n",
    "        rnn_inputs = tf.unstack(x_one_hot,axis=1)\n",
    "        rnn_labels = tf.unstack(y_one_hot,axis=1)\n",
    "\n",
    "        state = init_state\n",
    "        rnn_outputs = []\n",
    "        for rnn_input in rnn_inputs:\n",
    "            state = rnn_cell(rnn_input, state)\n",
    "            rnn_outputs.append(state)\n",
    "        final_state = state\n",
    "\n",
    "        with tf.variable_scope(str(id(self)) + 'softmax'):\n",
    "            w_out = tf.get_variable('w',[state_size, vocab_size])\n",
    "            b_out = tf.get_variable('b',[vocab_size])\n",
    "\n",
    "        logits = [tf.matmul(rnn_output, w_out) + b_out for rnn_output in rnn_outputs]\n",
    "        losses = [tf.nn.softmax_cross_entropy_with_logits_v2(labels=label, logits=logit)\n",
    "                  for logit, label in zip(logits, rnn_labels)]\n",
    "        total_loss = tf.reduce_mean(losses)\n",
    "        update = tf.train.AdagradOptimizer(learning_rate).minimize(total_loss)\n",
    "\n",
    "        return x,y,init_state,final_state,total_loss,update\n",
    "\n",
    "    def train(self,num_epochs):\n",
    "        \"\"\"Build the training graph, initialise all variables and train.\n",
    "\n",
    "        Args:\n",
    "            num_epochs: Number of passes over the corpus.\n",
    "\n",
    "        Returns:\n",
    "            A list with the mean per-step loss of every completed block of\n",
    "            10 epochs, in chronological order.\n",
    "        \"\"\"\n",
    "        x,y,init_state,final_state,total_loss,update = self.create_compute_graph()\n",
    "\n",
    "        self.sess.run(tf.global_variables_initializer())\n",
    "\n",
    "        report_every = 10\n",
    "        training_losses = []\n",
    "        # Loss and step count accumulated since the previous report, so the\n",
    "        # printed figure really is an average over the last 10 epochs.\n",
    "        window_loss = 0.0\n",
    "        window_steps = 0\n",
    "        for epoch_no, epoch in enumerate(get_epoch(num_epochs), start=1):\n",
    "            # The hidden state starts at zero for each epoch and is then\n",
    "            # carried from one batch to the next.\n",
    "            training_state = np.zeros((batch_size, state_size))\n",
    "            for X, Y in epoch:\n",
    "                step_loss, training_state, _ = self.sess.run(\n",
    "                    [total_loss, final_state, update],\n",
    "                    feed_dict={x: X, y: Y, init_state: training_state})\n",
    "                window_loss += float(step_loss)\n",
    "                window_steps += 1\n",
    "            if epoch_no % report_every == 0:\n",
    "                # max(..., 1) guards against a corpus too short to yield a batch.\n",
    "                average_loss = window_loss / max(window_steps, 1)\n",
    "                print(\"Average loss at epoch\", epoch_no,\n",
    "                      \"for last 10 epochs:\", average_loss)\n",
    "                training_losses.append(average_loss)\n",
    "                window_loss, window_steps = 0.0, 0\n",
    "        print('train finished')\n",
    "        return training_losses\n",
    "\n",
    "    def create_test_graph(self):\n",
    "        \"\"\"Build a single-step graph that reuses the trained variables.\n",
    "\n",
    "        The variables are looked up with ``reuse=True``, so they must already\n",
    "        have been created by ``create_compute_graph``.\n",
    "\n",
    "        Returns:\n",
    "            A tuple ``(x, init_state, state, out)``: the one-character input\n",
    "            placeholder, the previous-state placeholder, the new state and the\n",
    "            id of the most probable next character.\n",
    "        \"\"\"\n",
    "        x = tf.placeholder(tf.int32,[1])\n",
    "        x_one_hot = tf.one_hot(x,vocab_size)\n",
    "        init_state = tf.placeholder(tf.float32,[1,state_size])\n",
    "\n",
    "        with tf.variable_scope(str(id(self)) + 'rnn_cell',reuse=True):\n",
    "            w = tf.get_variable('w',[vocab_size + state_size, state_size])\n",
    "            b = tf.get_variable('b',[state_size])\n",
    "\n",
    "        state = tf.tanh(tf.matmul(tf.concat([x_one_hot,init_state],axis=1),w) + b)\n",
    "\n",
    "        with tf.variable_scope(str(id(self)) + 'softmax', reuse=True):\n",
    "            w2 = tf.get_variable('w',[state_size, vocab_size])\n",
    "            b2 = tf.get_variable('b',[vocab_size])\n",
    "        y = tf.matmul(state,w2) + b2\n",
    "        p = tf.nn.softmax(y)\n",
    "        out = tf.argmax(p,axis=1)\n",
    "        return x, init_state,state, out\n",
    "\n",
    "    def sample(self,n):\n",
    "        \"\"\"Generate ``n`` characters by always picking the most probable one.\n",
    "\n",
    "        Generation starts from the first character of the corpus with a zero\n",
    "        hidden state.\n",
    "\n",
    "        Args:\n",
    "            n: Number of characters to produce.\n",
    "\n",
    "        Returns:\n",
    "            The generated text as a ``str`` of length ``n``.\n",
    "        \"\"\"\n",
    "        x, init_state, state,out = self.create_test_graph()\n",
    "        test_x = np.array([char2idx.get(data[0])])\n",
    "        training_state = np.zeros([1,state_size])\n",
    "        result = []\n",
    "        for _ in range(n):\n",
    "            result.append(test_x[0])\n",
    "            # Feed the predicted character and new state back in as the next input.\n",
    "            training_state,test_x = self.sess.run(\n",
    "                [state,out],\n",
    "                feed_dict={x: test_x, init_state: training_state})\n",
    "        return ''.join(idx2char[i] for i in result)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "start traning...\n",
      "WARNING:tensorflow:From d:\\software\\anaconda3\\envs\\bdmi\\lib\\site-packages\\tensorflow_core\\python\\training\\adagrad.py:76: calling Constant.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Call initializer instance with the dtype argument instead of passing it to the constructor\n",
      "Average loss at epoch 10 for last 10 epochs: 0.7381861710548401\n",
      "Average loss at epoch 20 for last 10 epochs: 0.5140705299377442\n",
      "Average loss at epoch 30 for last 10 epochs: 0.3861145704984665\n",
      "Average loss at epoch 40 for last 10 epochs: 0.3100062629580498\n",
      "Average loss at epoch 50 for last 10 epochs: 0.246302972137928\n",
      "Average loss at epoch 60 for last 10 epochs: 0.2039945475757122\n",
      "Average loss at epoch 70 for last 10 epochs: 0.17008764773607254\n",
      "Average loss at epoch 80 for last 10 epochs: 0.1492837519943714\n",
      "Average loss at epoch 90 for last 10 epochs: 0.13327353641390802\n",
      "train finished\n",
      "start testing...\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "'# lodid             rnite_date)\\n    brrat_camistep,in srate, bo sel(btes)\\n\\n                         '"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Train a fresh model on the corpus, then sample text from it.\n",
    "model = CharRnnModel()\n",
    "print('start traning...')\n",
    "model.train(num_epochs=100)\n",
    "\n",
    "print('start testing...')\n",
    "model.sample(n=100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
